{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f293d717",
   "metadata": {},
   "source": [
    "# RICO analysis\n",
    "This notebook qualitatively analyzes models trained on the RICO dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb8241e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1abd9f1a",
   "metadata": {},
   "source": [
    "##### Editable parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f63409a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# checkpoint directory of the model to analyze (passed to load_model)\n",
    "ckpt_dir = \"../results/rico/ours-exp-ft/checkpoints\"\n",
    "# dataset name, dataset root and batch size (passed to DataSpec)\n",
    "dataset_name = \"rico\"\n",
    "db_root = \"../data/rico\"\n",
    "batch_size = 4"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ec3beb72",
   "metadata": {},
   "source": [
    "##### Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ca99271",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import itertools\n",
    "import logging\n",
    "import random\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from IPython.display import display, HTML\n",
    "%matplotlib inline\n",
    "\n",
    "sys.path.append(\"../src/mfp\")\n",
    "\n",
    "from mfp.models.mfp import MFP, merge_inputs_and_prediction\n",
    "from mfp.models.architecture.mask import get_seq_mask\n",
    "from mfp.models.masking import get_initial_masks\n",
    "from mfp.data import DataSpec\n",
    "from mfp.helpers import svg_rico as svg\n",
    "from util import grouper, load_model\n",
    "\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "# fix seeds for debug: the masked element is chosen with Python's `random`,\n",
    "# so seeding TensorFlow alone does not make the visualization reproducible\n",
    "random.seed(0)\n",
    "np.random.seed(0)\n",
    "tf.random.set_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20978d1b",
   "metadata": {},
   "source": [
    "##### Load datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "034361e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataspec = DataSpec(dataset_name, db_root, batch_size)\n",
    "test_dataset = dataspec.make_dataset(\"test\", shuffle=False)\n",
    "# only the first batch of the (unshuffled) test split is used in this notebook\n",
    "iterator = iter(test_dataset.take(1))\n",
    "example = next(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f08c110f",
   "metadata": {},
   "source": [
    "##### Load pre-trained models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efde3727",
   "metadata": {},
   "outputs": [],
   "source": [
    "# column specification handed to load_model\n",
    "input_columns = dataspec.make_input_columns()\n",
    "# models are keyed by name; all values are passed to visualize_reconstruction\n",
    "models = {\"main\": load_model(ckpt_dir, input_columns=input_columns)}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9db7358",
   "metadata": {},
   "source": [
    "##### Define some helpers for the ELEM-filling task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fcc5829",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SVG builder used below for inputs, predictions and ground truth alike\n",
    "builder0 = svg.SVGBuilder(\n",
    "    max_width=128,\n",
    "    max_height=192,\n",
    "    key=\"type\",\n",
    "    preprocessor=dataspec.preprocessor,\n",
    ")\n",
    "\n",
    "# demo for ELEM prediction (randomly mask a single element)\n",
    "def visualize_reconstruction(\n",
    "    models, example, dataspec, input_builders, output_builders, input_columns=None\n",
    "):\n",
    "    \"\"\"Run the ELEM-filling demo: hide one random element per sample and predict it.\n",
    "\n",
    "    Args:\n",
    "        models: iterable of models, each called as\n",
    "            ``model(example, training=False, demo_args={\"masks\": masks})``.\n",
    "        example: batched dict of tensors; must contain \"length\" and \"left\".\n",
    "        dataspec: provides ``unbatch`` (and ``make_input_columns`` when\n",
    "            ``input_columns`` is not given).\n",
    "        input_builders: callables rendering each unbatched input / ground-truth sample.\n",
    "        output_builders: callables rendering each unbatched prediction.\n",
    "        input_columns: optional column specification. Defaults to\n",
    "            ``dataspec.make_input_columns()`` so the helper does not depend on a\n",
    "            notebook-level global.\n",
    "\n",
    "    Returns:\n",
    "        A list with one entry per sample. Each entry holds the renderings in the\n",
    "        order input (one element removed), prediction(s), ground truth, chunked\n",
    "        by ``grouper`` into groups of ``len(input_builders)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a sample has no element that could be hidden.\n",
    "    \"\"\"\n",
    "    if input_columns is None:\n",
    "        input_columns = dataspec.make_input_columns()\n",
    "\n",
    "    seq_mask = get_seq_mask(example[\"length\"])\n",
    "    mfp_masks = get_initial_masks(input_columns, seq_mask)\n",
    "    example_copy = copy.deepcopy(example)\n",
    "\n",
    "    # number of valid elements per sample\n",
    "    n_elem = tf.reduce_sum(tf.cast(seq_mask, tf.int32), axis=1).numpy()\n",
    "    if (n_elem < 1).any():\n",
    "        raise ValueError(\"every sample needs at least one element to hide\")\n",
    "    # position of the element to hide, drawn independently for each sample\n",
    "    target_indices = [random.randint(0, int(n) - 1) for n in n_elem]\n",
    "\n",
    "    # build the model *input visualization*: the same batch with the target element removed\n",
    "    S = example_copy[\"left\"].shape[1]\n",
    "    indices = tf.convert_to_tensor(\n",
    "        np.array([[j for j in range(S) if j != target] for target in target_indices])\n",
    "    )\n",
    "    for key in example_copy.keys():\n",
    "        # NOTE(review): shape[1] > 1 is used as the marker of a per-element column -- confirm\n",
    "        if example_copy[key].shape[1] > 1:\n",
    "            example_copy[key] = tf.gather(example_copy[key], indices, batch_dims=1)\n",
    "    example_copy[\"length\"] -= 1\n",
    "\n",
    "    svgs = []\n",
    "    for builder in input_builders:\n",
    "        svgs.append(list(map(builder, dataspec.unbatch(example_copy))))\n",
    "\n",
    "    # mark the target element as masked in every sequence column\n",
    "    for key in mfp_masks.keys():\n",
    "        if not input_columns[key][\"is_sequence\"]:\n",
    "            continue\n",
    "        dummy = mfp_masks[key].numpy()\n",
    "        for i, target in enumerate(target_indices):\n",
    "            dummy[i, target] = True  # hide single element for each sample\n",
    "        mfp_masks[key] = tf.convert_to_tensor(dummy)\n",
    "\n",
    "    # the models receive the full example; the masks tell them what to predict\n",
    "    for model in models:\n",
    "        pred = model(example, training=False, demo_args={\"masks\": mfp_masks})\n",
    "        pred = merge_inputs_and_prediction(example, input_columns, mfp_masks, pred)\n",
    "\n",
    "        for builder in output_builders:\n",
    "            svgs.append(list(map(builder, dataspec.unbatch(pred))))\n",
    "\n",
    "    # ground truth comes last\n",
    "    for builder in input_builders:\n",
    "        svgs.append(list(map(builder, dataspec.unbatch(example))))\n",
    "\n",
    "    # transpose from per-column lists to per-sample rows\n",
    "    return [list(grouper(row, len(input_builders))) for row in zip(*svgs)]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "706181c4",
   "metadata": {},
   "source": [
    "##### Visualization of results\n",
    "From left to right: input (one element missing), prediction, ground truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c420788b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# one row per sample; each row holds the groups returned by visualize_reconstruction\n",
    "svgs = visualize_reconstruction(models.values(), example, dataspec, [builder0], [builder0])\n",
    "for i, row in enumerate(svgs):\n",
    "    print(i)\n",
    "    # flatten the groups and join all renderings of the sample into a single <div>\n",
    "    display(HTML(\"<div>%s</div>\" % \" \".join(itertools.chain.from_iterable(row))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.12 ('.venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "b1cc4bcd870fb3eb296f14a2aa1daa467f3c14f214d3fd2c136db8d539a2d2c9"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
